import torch
import torch.distributed as dist
import torch.nn.functional as F

try:
    import fused_mix_prec_layer_norm_cuda
except:
    fused_mix_prec_layer_norm_cuda = None

try:
    import fused_weight_gradient_mlp_cuda

    _grad_accum_fusion_available = True
except ImportError:
    _grad_accum_fusion_available = False

from colossalai.shardformer.layer._operation import _reduce
import numpy as np

def reduce_backward_QT(input_, process_group):
    """Route ``input_`` through the ``_ReduceBackward_QT`` autograd function.

    The forward pass returns ``input_`` as-is. In the backward pass the
    incoming gradient is quantized to int8 and all-reduced across
    ``process_group``.

    Args:
        input_: tensor whose gradient should be reduced.
        process_group: process group used for the backward all-reduce.

    Returns:
        The output of ``_ReduceBackward_QT.apply`` (``input_`` in forward).
    """
    autograd_fn = _ReduceBackward_QT
    return autograd_fn.apply(input_, process_group)

class _ReduceBackward_QT(torch.autograd.Function):
    """
    Identity in forward; int8-quantized all-reduce of the gradient in backward.

    Backward procedure:
      1. Sum each rank's gradient min and max across the group.
      2. Quantize the local gradient to int8 with a shared symmetric scale.
      3. All-reduce (sum) the int8 tensors.
      4. Dequantize back to the gradient's original dtype.

    Args (to ``apply``):
        input_: input tensor, returned unchanged by forward.
        process_group: process group the gradient is reduced over.
    """

    @staticmethod
    def quantize(tensor: torch.Tensor, min_val, max_val, qmax: int = 127):
        """Quantize ``tensor`` to int8 with a symmetric scale.

        Args:
            tensor: floating-point tensor to quantize.
            min_val: 0-dim tensor, lower bound of the range to represent.
            max_val: 0-dim tensor, upper bound of the range to represent.
            qmax: largest integer magnitude to use (<= 127). The default of
                127 uses the full symmetric int8 range.

        Returns:
            ``(quantized, scale)``: an int8 tensor clamped to ``[-qmax, qmax]``
            and the scale such that ``quantized * scale`` approximates ``tensor``.
        """
        abs_max = torch.maximum(min_val.abs(), max_val.abs())
        scale = abs_max / qmax
        # An all-zero range would give scale == 0, so tensor / scale is 0/0 = NaN,
        # and casting NaN to int8 is undefined. The smallest positive float keeps
        # zeros at zero.
        scale = scale.clamp_min(torch.finfo(scale.dtype).tiny)
        quantized = (tensor / scale).round().clamp(-qmax, qmax).to(torch.int8)
        return quantized, scale

    @staticmethod
    def dequantize(tensor, scale, dtype=torch.bfloat16):
        """Map int8 values back to floating point.

        Args:
            tensor: int8 tensor (e.g. after the all-reduce).
            scale: scale returned by ``quantize``.
            dtype: output dtype (defaults to bfloat16 for backward compatibility).

        Returns:
            ``tensor * scale`` computed in float32, then cast to ``dtype``.
        """
        return (tensor.to(torch.float32) * scale).to(dtype)

    @staticmethod
    def _sum_safe_qmax(world_size: int) -> int:
        """Return the per-rank quantization level that keeps an int8 sum in range.

        The element-wise sum of real values is bounded by the shared scale.
        Rounding on each rank can still add up to 0.5 per rank, so
        ``ceil(world_size / 2)`` levels of headroom are reserved below 127.
        """
        return max(1, 127 - (world_size + 1) // 2)

    @staticmethod
    def forward(ctx, input_, process_group):
        ctx.process_group = process_group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        group = ctx.process_group
        world_size = dist.get_world_size(group)
        if world_size == 1:
            # Nothing to communicate, so the lossy int8 round trip is pure error.
            return grad_output, None

        orig_dtype = grad_output.dtype
        grad_fp32 = grad_output.to(torch.float32)

        # Sum the per-rank extremes in a single collective. Element-wise, the
        # reduced gradient lies in [sum(min_r), sum(max_r)], so a scale derived
        # from these bounds the *summed* int8 values, not just local ones. It is
        # also identical on every rank, which dequantization after the sum needs.
        extremes = torch.stack((grad_fp32.min(), grad_fp32.max()))
        dist.all_reduce(extremes, op=dist.ReduceOp.SUM, group=group)
        min_val, max_val = extremes[0], extremes[1]

        # Leave headroom for per-rank rounding so the int8 sum cannot wrap around.
        qmax = _ReduceBackward_QT._sum_safe_qmax(world_size)
        grad_quant, scale = _ReduceBackward_QT.quantize(
            grad_fp32, min_val, max_val, qmax=qmax
        )
        grad_quant = _reduce(grad_quant, group)
        return _ReduceBackward_QT.dequantize(grad_quant, scale, dtype=orig_dtype), None